# --------------------------
# -*- coding=utf-8 -*-
#
# written by zhaoxingjie
# 18829350080@163.com
# --------------------------
import os

from sklearn.model_selection import train_test_split


def write_txt(txt_path, images):
    """Write an image-id list file, one id per line.

    Each entry of ``images`` is a file name such as ``000001.jpg``. Its
    extension is removed and the remaining stem is written on its own line.
    Any existing file at ``txt_path`` is overwritten. Each original file name
    is also printed to stdout as progress output.

    Args:
        txt_path: Destination text file.
        images: Iterable of image file names.
    """
    with open(txt_path, 'w', encoding='utf-8') as f:
        for name in images:
            # splitext handles any extension length ('.jpg', '.jpeg', '.tiff'),
            # unlike a fixed four-character slice.
            f.write(os.path.splitext(name)[0] + '\n')
            print(name)


# Root of the VOC-format dataset; every path below is derived from it.
voc_root = '/media/hp208/4t/zhaoxingjie/project_graduation/data/test-dev-voc-format-ir-3cls/VOC2007'

img_path = os.path.join(voc_root, 'JPEGImages', '')  # keeps the trailing separator
split_dir = os.path.join(voc_root, 'ImageSets', 'Main')
# NOTE: despite its name, train_txt_path points at trainval.txt.
train_txt_path = os.path.join(split_dir, 'trainval.txt')
val_txt_path = os.path.join(split_dir, 'val.txt')
test_txt_path = os.path.join(split_dir, 'test.txt')

# Fixed seed so that re-running the script reproduces exactly the same split.
random_seed = 0
# Only regular files with these extensions are treated as images; anything
# else found in JPEGImages (sub-directories, stray files) is ignored.
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

# os.listdir order is arbitrary; sorting makes the seeded shuffle reproducible.
all_images = sorted(
    name for name in os.listdir(img_path)
    if name.lower().endswith(image_extensions)
    and os.path.isfile(os.path.join(img_path, name))
)
if not all_images:
    raise ValueError('No image files found in ' + img_path)

# 80% trainval / 20% test.
trainval_images, test_images = train_test_split(
    all_images, train_size=0.8, test_size=0.2, shuffle=True,
    random_state=random_seed)
# val is a 10% subset drawn from trainval, so it overlaps trainval.txt
# (trainval = train + val, as in the VOC layout). The remaining 90% ("train")
# is discarded here and no train.txt is written.
_, val_images = train_test_split(
    trainval_images, train_size=0.9, test_size=0.1, shuffle=True,
    random_state=random_seed)

# ImageSets/Main may not exist yet in a freshly converted dataset.
os.makedirs(split_dir, exist_ok=True)

write_txt(train_txt_path, trainval_images)
write_txt(val_txt_path, val_images)
write_txt(test_txt_path, test_images)
